Assignment 2 Music Generation¶
Installation & Imports¶
# Install required libraries
# !pip install torch pretty_midi matplotlib midi2audio librosa
# Imports
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import pretty_midi
import matplotlib.pyplot as plt
import pickle
import torch.nn.functional as F
import librosa
import collections
import math
from torch.utils.data import Dataset, DataLoader
from midi2audio import FluidSynth
from IPython.display import Audio, display
from matplotlib.ticker import MaxNLocator
Data Loading¶
# Load the pre-serialized JSB Chorales dataset.
# NOTE(review): pickle.load is fine for this local, trusted dataset file, but
# never unpickle untrusted data — pickle can execute arbitrary code.
with open("JSB-Chorales-dataset-master/jsb-chorales-quarter.pkl", "rb") as f:
    data = pickle.load(f, encoding="latin1")  # latin1 decodes Python-2-era pickles
# We'll work with the training split here; 'valid' and 'test' splits are also in `data`.
chorales = data["train"]
print(f"Loaded {len(chorales)} training chorales.")
print("Sample:", chorales[0][:5])
Loaded 229 training chorales. Sample: [(60, 72, 79, 88), (72, 79, 88), (67, 70, 76, 84), (69, 77, 86), (67, 70, 79, 88)]
Dataset Context¶
The JSB Chorales dataset consists of 382 four-part harmonized chorales by J.S. Bach. It is widely used in symbolic music modeling and has been curated to support machine learning tasks. We use the version released by Zhuang et al., which is represented as a sequence of four‐voice chord events (soprano, alto, tenor, bass), quantized to quarter‐note durations.
Instead of modeling only the soprano line, we now build a polyphonic model that learns full four‐voice chorales in parallel. At each time step, the model will predict an entire 4‐tuple of MIDI pitches (or rests) for all voices simultaneously.
Preprocessing Steps¶
Extract four‐voice chord tuples
- For each chorale, read each 4‐element chord event (one MIDI pitch per voice).
- Skip any chord where all four voices are rests (-1, -1, -1, -1).
- Drop any chorale with 10 or fewer valid chords (the code keeps only chorales longer than 10 chords).
Build a chord vocabulary
- Collect the set of all unique 4‐tuples (soprano, alto, tenor, bass) across the training split.
- Map each unique chord‐tuple to a distinct integer index.
Tokenize each chorale as a sequence of chord‐indices
- Convert each 4‐tuple in a chorale to its index in the chord vocabulary.
- Discard any chord not found in the vocabulary (e.g., if it only appeared in validation/test).
Prepare sequence‐to‐sequence training pairs
- Slide a fixed‐length window (e.g., 32 chords) over each tokenized chord sequence.
- For each window, the input is the first 32 chord‐indices, and the target is the next 32 chord‐indices (shifted by one).
Build `ChordSequenceDataset` and `DataLoader`
- Wrap the tokenized sequences of indices in a PyTorch `Dataset` that returns `(input_seq, target_seq)` pairs.
- Use a `DataLoader` with a suitable batch size (e.g., 64) to feed the LSTM.
After these steps, we feed full four‐voice chord sequences into our MusicRNN model so that at each step it learns to predict a 4‐voice chord rather than a single monophonic melody.
# Build four-voice (soprano, alto, tenor, bass) chord sequences from the raw
# chorales. All-rest events (-1 in every voice) are discarded, and chorales
# with 10 or fewer surviving chords are dropped entirely.
chord_seqs = []
for chorale in chorales:
    kept = []
    for event in chorale:
        # Each event is either a 4-element list/tuple of MIDI pitches,
        # or a bare -1 marking a complete rest — skip the latter.
        if not (isinstance(event, (list, tuple)) and len(event) == 4):
            continue
        voices = (int(event[0]), int(event[1]), int(event[2]), int(event[3]))
        # Keep the chord unless every voice is resting.
        if voices != (-1, -1, -1, -1):
            kept.append(voices)
    # Only keep chorales longer than 10 chords.
    if len(kept) > 10:
        chord_seqs.append(kept)
print(f"Extracted {len(chord_seqs)} four‐voice sequences.")
print("Example chord‐sequence (first 5 chords):", chord_seqs[0][:5])
Extracted 229 four‐voice sequences. Example chord‐sequence (first 5 chords): [(60, 72, 79, 88), (67, 70, 76, 84), (67, 70, 79, 88), (65, 72, 81, 89), (65, 72, 81, 89)]
Vocabulary & Tokenization¶
# Enumerate every distinct (S, A, T, B) chord tuple that occurs in the
# training split, sorted for a deterministic index assignment.
all_chords = sorted({tuple(c) for seq in chord_seqs for c in seq})
# Forward and inverse mappings between chord tuples and integer token ids.
chord_to_idx = {ch: i for i, ch in enumerate(all_chords)}
idx_to_chord = {i: ch for ch, i in chord_to_idx.items()}
vocab_size = len(chord_to_idx)
# Re-express every chorale as a list of chord token ids.
tokenized_chord_seqs = [[chord_to_idx[c] for c in seq] for seq in chord_seqs]
print("Four‐voice chord vocabulary size:", vocab_size)
print("Tokenized example (first 10 chord‐indices):", tokenized_chord_seqs[0][:10])
Four‐voice chord vocabulary size: 2113 Tokenized example (first 10 chord‐indices): [736, 1496, 1502, 1338, 1338, 537, 1697, 1634, 1704, 1445]
Dataset Class¶
# Create Dataset class for LSTM training.
# Takes tokenized chord sequences and splits them into
# fixed-length input-output pairs.
class ChordSequenceDataset(Dataset):
    """Sliding-window dataset over tokenized chord sequences.

    Each sample is a pair of LongTensors (input_seq, target_seq) of length
    ``seq_len``, where the target is the input window shifted one step ahead.
    """

    def __init__(self, token_chord_seqs, seq_len=32):
        super().__init__()
        # Pre-materialize every (input, target) window across all sequences.
        self.samples = [
            (seq[start:start + seq_len], seq[start + 1:start + seq_len + 1])
            for seq in token_chord_seqs
            for start in range(len(seq) - seq_len)
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        inputs, targets = self.samples[idx]
        # LongTensors of shape (seq_len,) holding chord token ids.
        return (torch.tensor(inputs, dtype=torch.long),
                torch.tensor(targets, dtype=torch.long))
DataLoader Preparation¶
# Create batches of (input, target) pairs for training.
seq_len = 32  # window length; the target is the same window shifted by one chord
batch_size = 64  # number of (input, target) window pairs per training batch
# Create dataset and dataloader
dataset = ChordSequenceDataset(tokenized_chord_seqs, seq_len=seq_len)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)  # reshuffle windows every epoch
print(f"Total training chord‐sequences: {len(dataset)}")
Total training chord‐sequences: 5186
Training Model¶
class MusicRNN(nn.Module):
    """LSTM language model over chord tokens with learned positional embeddings.

    Given a (batch, seq) tensor of chord indices, produces logits of shape
    (batch, seq, vocab_size) that score the next chord at every timestep.
    """

    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, num_layers=2, seq_len=32):
        super().__init__()
        # Token embedding: chord index -> dense vector.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Learned positional embedding for timesteps 0..seq_len-1.
        self.position_embed = nn.Embedding(seq_len, embedding_dim)
        # Stacked LSTM with inter-layer dropout; tensors are (batch, seq, feature).
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.2,
        )
        # Normalize and regularize the LSTM outputs before the final projection.
        self.norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        n_batch, n_steps = x.shape
        # Position ids [0..n_steps-1], broadcast to every example in the batch.
        pos_ids = torch.arange(n_steps, device=x.device).unsqueeze(0).expand(n_batch, n_steps)
        hidden = self.embedding(x) + self.position_embed(pos_ids)
        lstm_out, _ = self.lstm(hidden)
        features = self.dropout(self.norm(lstm_out))
        # Logits of shape (batch, seq, vocab_size).
        return self.fc(features)
def train_rnn(model, dataloader, vocab_size, num_epochs=10, lr=0.001,
              device="cuda" if torch.cuda.is_available() else "cpu"):
    """
    Train the MusicRNN on the provided dataloader.
    model: instance of MusicRNN
    dataloader: yields (input_batch, target_batch) pairs
    vocab_size: size of the token vocabulary for loss calculation
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    # Halve the learning rate when the epoch loss plateaus for 2 epochs.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=2)

    for epoch in range(num_epochs):
        model.train()
        running = 0.0
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            # Flatten (batch, seq, vocab) logits against (batch, seq) targets
            # into one big batch of per-timestep classification problems.
            step_logits = model(inputs).view(-1, vocab_size)
            loss = criterion(step_logits, targets.view(-1))
            loss.backward()
            # Clip gradients to keep LSTM training stable.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            running += loss.item()
        mean_loss = running / len(dataloader)
        print(f"Epoch {epoch+1}/{num_epochs} | Loss: {mean_loss:.4f}")
        # The scheduler is driven by average TRAINING loss (no validation split here).
        scheduler.step(mean_loss)
# Train the model for 10 epochs on the full chord-window dataloader.
model = MusicRNN(vocab_size=vocab_size, seq_len=32)  # seq_len must match the dataset's window length
train_rnn(model, dataloader, vocab_size, num_epochs=10)
Epoch 1/10 | Loss: 5.1666 Epoch 2/10 | Loss: 2.5353 Epoch 3/10 | Loss: 1.3446 Epoch 4/10 | Loss: 0.8037 Epoch 5/10 | Loss: 0.5484 Epoch 6/10 | Loss: 0.4196 Epoch 7/10 | Loss: 0.3458 Epoch 8/10 | Loss: 0.2947 Epoch 9/10 | Loss: 0.2619 Epoch 10/10 | Loss: 0.2408
Sampling from the trained LSTM¶
# Three sampling strategies: (A) a random 4-chord prefix, (B) a single-chord
# "cold" start, or (C) a caller-supplied fixed seed.
def sample_diverse(
    model,
    tokenized_seqs,
    max_length=64,
    prefix_type="random_short",  # "random_short", "single", or "fixed"
    fixed_prefix=None,           # only used if prefix_type=="fixed"
    prefix_len=4,
    first_steps_temp=2.0,        # high temperature for the first len(prefix) steps
    normal_temp=1.0,
    top_k=5,
    top_p=0.8,
    device="cuda" if torch.cuda.is_available() else "cpu"
):
    """
    Autoregressively sample a chord-index sequence from a trained model.

    prefix_type:
      - "fixed": uses fixed_prefix (list of IDs)
      - "random_short": picks a random sequence and takes its first prefix_len tokens
      - "single": starts from 1 random token

    Returns a list of token ids of length max_length (seed prefix included).
    """
    model.eval().to(device)

    # Pick our seed.
    if prefix_type == "fixed":
        assert fixed_prefix is not None
        prefix = fixed_prefix
    elif prefix_type == "random_short":
        prefix = random.choice(tokenized_seqs)[:prefix_len]
    elif prefix_type == "single":
        prefix = [random.choice(tokenized_seqs)[0]]
    else:
        raise ValueError("bad prefix_type")

    generated = list(prefix)

    def filter_logits(logits):
        """Apply top-k then nucleus (top-p) filtering to a 1-D logits tensor."""
        # FIX: use module-level F.softmax instead of a redundant nested import.
        logits = logits.clone()
        if top_k > 0:
            # FIX: cap k at the vocabulary size — torch.topk raises if k > numel.
            kth = torch.topk(logits, min(top_k, logits.size(-1)))[0][-1]
            logits[logits < kth] = -1e9
        if top_p > 0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            cum = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
            mask = cum > top_p
            # Shift right so the first token crossing the threshold is kept.
            mask[..., 1:] = mask[..., :-1].clone()
            mask[..., 0] = False
            logits[sorted_idx[mask]] = -1e9
        return logits

    # FIX: sampling is inference-only; torch.no_grad() stops autograd from
    # building a graph at every step (saves memory/time, identical output).
    with torch.no_grad():
        for step in range(max_length - len(prefix)):
            # Higher temperature for the first len(prefix) steps, then normal.
            temp = first_steps_temp if step < len(prefix) else normal_temp
            # Feed at most the model's positional-embedding capacity.
            window = model.position_embed.num_embeddings
            inp = torch.tensor([generated[-window:]], device=device)
            logits = model(inp)[0, -1, :] / temp
            probs = F.softmax(filter_logits(logits), dim=-1)
            generated.append(torch.multinomial(probs, 1).item())
    return generated
# Generate with all three seeding strategies:
gens = {}
# (A) seed with the first 4 chords of a randomly chosen training chorale
gens["A_random4"] = sample_diverse(
    model,
    tokenized_chord_seqs,
    prefix_type="random_short",
    prefix_len=4
)
# (B) "cold" start from a single randomly chosen chord token
gens["B_single"] = sample_diverse(
    model,
    tokenized_chord_seqs,
    prefix_type="single"
)
# (C) deterministic seed taken from the training data
gens["C_fixed4"] = sample_diverse(
    model,
    tokenized_chord_seqs,
    prefix_type="fixed",
    fixed_prefix=tokenized_chord_seqs[0][:4]  # first 4 chords of the first chorale
)
# Map each generated chord-index sequence back to actual (S, A, T, B) tuples
chord_sequences = {
    name: [idx_to_chord[idx] for idx in seq]
    for name, seq in gens.items()
}
# The three generation strategies as chord-tuple sequences
generated_chords = chord_sequences["A_random4"]
generated_chords2 = chord_sequences["B_single"]
generated_chords3 = chord_sequences["C_fixed4"]
Save original & generated as MIDI and convert to WAV for listening¶
# Write a chord sequence (4-tuples or vocabulary indices) to a .mid file,
# sounding all four voice-notes in parallel at each time step.
def save_four_voice_midi(chord_seq, filename="polyphonic_output.mid", note_duration=0.5):
    """Render a four-voice chord sequence as a single-instrument MIDI file.

    Items of chord_seq may be 4-tuples of MIDI pitches or chord-vocabulary
    indices (resolved via the module-level idx_to_chord). A pitch of -1
    marks a rest in that voice.
    """
    midi = pretty_midi.PrettyMIDI()
    piano = pretty_midi.Instrument(program=0)  # one piano track carries all voices
    t = 0.0
    for item in chord_seq:
        # Accept either a ready-made tuple or a vocabulary index.
        voices = item if isinstance(item, tuple) else idx_to_chord[item]
        for pitch in voices:
            if pitch == -1:
                continue  # rest in this voice
            piano.notes.append(pretty_midi.Note(
                velocity=100,
                pitch=pitch,
                start=t,
                end=t + note_duration,
            ))
        t += note_duration
    midi.instruments.append(piano)
    midi.write(filename)
# Write a four-voice MIDI file for each generation strategy
save_four_voice_midi(generated_chords, filename="generated_chords_A.mid")
save_four_voice_midi(generated_chords2, filename="generated_chords_B.mid")
save_four_voice_midi(generated_chords3, filename="generated_chords_C.mid")
# Also render the first 64 chords of the first training chorale for comparison
save_four_voice_midi(tokenized_chord_seqs[0][:64], filename="original_chords.mid")
# Synthesize each MIDI file to WAV (requires the FluidR3_GM soundfont on disk)
fs = FluidSynth("FluidR3_GM.sf2")
fs.midi_to_audio("original_chords.mid", "original_chords.wav")
fs.midi_to_audio("generated_chords_A.mid", "generated_A.wav")
fs.midi_to_audio("generated_chords_B.mid", "generated_B.wav")
fs.midi_to_audio("generated_chords_C.mid", "generated_C.wav")
# Play back and display audio in notebook
print("🎹 Original four-voice:")
display(Audio("original_chords.wav"))
print("🎹 Generated (A: random4):")
display(Audio("generated_A.wav"))
print("🎹 Generated (B: single-chord cold start):")
display(Audio("generated_B.wav"))
print("🎹 Generated (C: fixed4 prefix):")
display(Audio("generated_C.wav"))
FluidSynth runtime version 2.3.5 Copyright (C) 2000-2024 Peter Hanappe and others. Distributed under the LGPL license. SoundFont(R) is a registered trademark of Creative Technology Ltd. Rendering audio to file 'original_chords.wav'.. FluidSynth runtime version 2.3.5 Copyright (C) 2000-2024 Peter Hanappe and others. Distributed under the LGPL license. SoundFont(R) is a registered trademark of Creative Technology Ltd. Rendering audio to file 'generated_A.wav'.. FluidSynth runtime version 2.3.5 Copyright (C) 2000-2024 Peter Hanappe and others. Distributed under the LGPL license. SoundFont(R) is a registered trademark of Creative Technology Ltd. Rendering audio to file 'generated_B.wav'.. FluidSynth runtime version 2.3.5 Copyright (C) 2000-2024 Peter Hanappe and others. Distributed under the LGPL license. SoundFont(R) is a registered trademark of Creative Technology Ltd. Rendering audio to file 'generated_C.wav'.. 🎹 Original four-voice:
🎹 Generated (A: random4):
🎹 Generated (B: single-chord cold start):